'''
@File    :   base_model.py
@Time    :   2021/10/01 22:40:33
@Author  :   Ming Ding
@Contact :   dm18@mails.tsinghua.edu.cn
'''

import argparse
import inspect
import math
import os
import random
import sys
import warnings
from functools import partial

import torch

from sat.arguments import (overwrite_args_by_dict, reset_random_seed,
                           set_random_seed, update_args_with_file)
from sat.helpers import print_rank0
from sat.model.base_transformer import GCBaseTransformer
from sat.model.mixins import BaseMixin
from sat.model.registry import MetaModel, model_registry
from sat.model.transformer import standard_attention
from sat.mpu.initialize import (destroy_model_parallel,
                                get_model_parallel_rank, get_node_rank,
                                initialize_model_parallel)
from sat.mpu.operation import (mp_merge_model_rank0, mp_merge_model_send,
                               mp_split_model_rank0, mp_split_model_receive)
from sat.resources import auto_create
from sat.training.model_io import load_checkpoint
from sat.transformer_defaults import ARGS_DEFAULT, HOOKS_DEFAULT


class BaseModel(torch.nn.Module, metaclass=MetaModel):

    def __init__(self,
                 args,
                 transformer=None,
                 params_dtype=torch.float,
                 **kwargs):
        super().__init__()
        self.mixins = torch.nn.ModuleDict()
        self.collect_hooks_()
        if transformer is not None:
            self.transformer = transformer
        else:
            # check if model-only mode
            from sat.arguments import _simple_init
            success = _simple_init(
                model_parallel_size=args.model_parallel_size,
                seed=args.seed if hasattr(args, 'seed') else 1234)

            args_dict = {
                k: getattr(args, v[0], v[1])
                for k, v in ARGS_DEFAULT.items()
            }

            self.transformer = GCBaseTransformer(
                num_layers=args.num_layers,
                vocab_size=args.vocab_size,
                hidden_size=args.hidden_size,
                num_attention_heads=args.num_attention_heads,
                max_sequence_length=args.max_sequence_length,
                layernorm_order=args.layernorm_order,
                **args_dict,
                hooks=self.hooks,
                params_dtype=params_dtype,
                skip_init=args.skip_init,
                device=(torch.cuda.current_device()
                        if getattr(args, 'use_gpu_initialization', False)
                        else torch.device('cpu')),
                **kwargs)
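
    # Construction sketch (hedged; the values are illustrative): build a model
    # without the full training launcher. `get_args` (defined below) fills the
    # architecture defaults, and keyword arguments override them.
    #
    #     args = BaseModel.get_args(num_layers=2, hidden_size=128,
    #                               num_attention_heads=4, vocab_size=1000,
    #                               max_sequence_length=256)
    #     model = BaseModel(args)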

    def reinit(self,
               mixin_names=None
               ):  # called when loading a model; None means all mixins
        # if some mixins are loaded from the checkpoint, override this function
        for k, m in self.mixins.items():
            if mixin_names is None or k in mixin_names:
                m.reinit(self)

    def add_mixin(self, name, new_mixin, reinit=False):
        assert name not in self.mixins
        assert isinstance(new_mixin, BaseMixin)

        self.mixins[name] = new_mixin  # will auto-register parameters
        object.__setattr__(new_mixin, 'transformer',
                           self.transformer)  # cannot use pytorch set_attr

        self.collect_hooks_()
        if reinit:
            new_mixin.reinit(self)  # also pass current mixins

    def del_mixin(self, name):
        assert name in self.mixins
        del self.mixins[name]
        self.collect_hooks_()

    def get_mixin(self, name):
        return self.mixins[name]
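
    # Mixin usage sketch (hedged; `LoRAMixin` is a hypothetical example, not
    # part of this file). A mixin bundles extra parameters and hook overrides;
    # `add_mixin` registers it and re-collects hooks automatically.
    #
    #     class LoRAMixin(BaseMixin):
    #         def __init__(self, hidden_size, r=8):
    #             super().__init__()
    #             self.lora_a = torch.nn.Linear(hidden_size, r, bias=False)
    #             self.lora_b = torch.nn.Linear(r, hidden_size, bias=False)
    #
    #     model.add_mixin('lora', LoRAMixin(args.hidden_size), reinit=True)
    #     model.get_mixin('lora')  # -> the LoRAMixin instance
    #     model.del_mixin('lora')  # detach it and re-collect hooks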

    def forward(self, *args, **kwargs):
        # update hooks to match the current model (overridden forwards)
        # Attention! the transformer might be shared by multiple models
        self.transformer.hooks.clear()
        self.transformer.hooks.update(self.hooks)
        return self.transformer(*args, **kwargs)

    def collect_hooks_(self):
        names = list(HOOKS_DEFAULT.keys())
        hooks = {}
        hook_origins = {}
        for name in names:
            if hasattr(self, name):
                hooks[name] = getattr(self, name)
                hook_origins[name] = 'model'

            for mixin_name, m in self.mixins.items():
                if hasattr(m, name):
                    if hasattr(getattr(m, name), 'non_conflict'):
                        # getattr(m, name) must accept old_impl as an argument
                        signature = inspect.signature(getattr(m, name))
                        if 'old_impl' not in signature.parameters:
                            raise ValueError(
                                f'Hook {name} at {mixin_name} must accept old_impl as an argument.'
                            )
                        # -------------
                        if name in hooks:
                            old_impl = hooks[name]
                        elif name == 'attention_fn':  # the only hook without self
                            old_impl = HOOKS_DEFAULT[name]
                        else:
                            old_impl = partial(
                                HOOKS_DEFAULT[name], self
                            )  # relax! `partial` does not affect the signature
                        old_origin = hook_origins.get(name, 'default')
                        hooks[name] = partial(getattr(m, name),
                                              old_impl=old_impl)
                        hook_origins[name] = mixin_name + ' -> ' + old_origin
                    elif name in hooks and not hasattr(
                            hooks[name], 'replacable'
                    ):  # if this hook name is already registered
                        raise ValueError(
                            f'Hook {name} conflicts at {mixin_name} and {hook_origins[name]}.'
                        )
                    else:  # new hook
                        if name in hooks and hasattr(hooks[name],
                                                     'replacable'):
                            warnings.warn(
                                f'Hook {name} at {mixin_name} replaces {hook_origins[name]}.'
                            )
                        hooks[name] = getattr(m, name)
                        hook_origins[name] = mixin_name

        self.hooks = hooks
        self.hook_origins = hook_origins
        return hooks
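
    # Hook-resolution sketch (hedged; `ScaledAttnMixin` is hypothetical): a
    # mixin method carrying a `non_conflict` attribute (e.g. set by sat's
    # `non_conflict` decorator, which is not shown in this file) wraps the
    # previous implementation instead of clashing with it, receiving it as
    # `old_impl`.
    #
    #     class ScaledAttnMixin(BaseMixin):
    #         @non_conflict
    #         def attention_fn(self, q, k, v, mask, *args,
    #                          old_impl=standard_attention, **kwargs):
    #             return old_impl(q * 0.5, k, v, mask, *args, **kwargs)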

    def disable_untrainable_params(self):
        pass

    @classmethod
    def add_model_specific_args(cls, parser):
        # recorded in arguments.py: add_model_config_args
        return parser

    @classmethod
    def from_pretrained_base(cls,
                             name,
                             args=None,
                             *,
                             home_path=None,
                             url=None,
                             prefix='',
                             build_only=False,
                             overwrite_args={},
                             **kwargs):
        '''Load a pretrained checkpoint of the current model.
            Args:
                name: The identifier of the pretrained model, or a local checkpoint directory.
                args: argparse.Namespace. The loaded args are merged into it; None creates a new model-only Namespace with defaults.
                home_path: the parent folder of the existing `name` model. Default: SAT_HOME.
                url: the url of the model. Default: SAT_URL.
                prefix: the prefix prepended to parameter names in the checkpoint. Default: ''.
                build_only: if True, build the model but do not load the checkpoint.
                overwrite_args: a dict whose entries override args loaded from `model_config.json`.
            Returns:
                model: the loaded model.
                args: the loaded args.
        '''
        if os.path.exists(name) and os.path.isdir(name):
            model_path = name
        else:
            model_path = auto_create(name, path=home_path, url=url)
        # create a new args if not provided
        if args is None:
            args = cls.get_args()
        args = update_args_with_file(args,
                                     path=os.path.join(model_path,
                                                       'model_config.json'))
        args = overwrite_args_by_dict(args, overwrite_args=overwrite_args)
        specific_iteration = kwargs.pop('specific_iteration', None)
        model = get_model(args, cls, **kwargs)
        if not build_only:
            load_checkpoint(model,
                            args,
                            load_path=model_path,
                            prefix=prefix,
                            specific_iteration=specific_iteration)
        return model, args
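
    # Loading sketch (hedged; the identifier and path are placeholders):
    # `name` may be a known model identifier resolved via `auto_create`, or a
    # local directory that contains `model_config.json` and the checkpoint.
    #
    #     model, args = SomeModel.from_pretrained_base('bert-base-uncased')
    #     model, args = SomeModel.from_pretrained_base('/path/to/checkpoint-dir')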

    @classmethod
    def from_pretrained(cls,
                        name,
                        args=None,
                        *,
                        home_path=None,
                        url=None,
                        prefix='',
                        build_only=False,
                        use_node_group=True,
                        overwrite_args={},
                        **kwargs):
        if build_only or 'model_parallel_size' not in overwrite_args:
            return cls.from_pretrained_base(name,
                                            args=args,
                                            home_path=home_path,
                                            url=url,
                                            prefix=prefix,
                                            build_only=build_only,
                                            overwrite_args=overwrite_args,
                                            **kwargs)
        else:
            new_model_parallel_size = overwrite_args['model_parallel_size']
            if new_model_parallel_size != 1 or (new_model_parallel_size == 1
                                                and args.model_parallel_size == 1):
                model, model_args = cls.from_pretrained_base(
                    name,
                    args=args,
                    home_path=home_path,
                    url=url,
                    prefix=prefix,
                    build_only=True,
                    overwrite_args=overwrite_args,
                    **kwargs)
                local_rank = (get_node_rank() if use_node_group
                              else get_model_parallel_rank())
                world_size = torch.distributed.get_world_size()
                assert world_size % new_model_parallel_size == 0, 'world size should be a multiple of the new model_parallel_size.'
                destroy_model_parallel()
                initialize_model_parallel(1)
                if local_rank == 0:
                    args.skip_init = True
                    args.use_gpu_initialization = False
                    args.device = 'cpu'
                    overwrite_args.pop('model_parallel_size')
                    model_full, args_ = cls.from_pretrained_base(
                        name,
                        args=args,
                        home_path=home_path,
                        url=url,
                        prefix=prefix,
                        build_only=False,
                        overwrite_args=overwrite_args,
                        **kwargs)
                    if args_.model_parallel_size != 1:
                        raise Exception(
                            "We do not support overwriting model_parallel_size when original model_parallel_size != 1. Try merging the model using `from_pretrained(xxx,overwrite_args={'model_parallel_size':1})` first if you still want to change model_parallel_size!"
                        )
                if hasattr(args, 'mode') and args.mode == 'inference':
                    # For multi-node inference, prevent rank 0 from eagerly
                    # printing info before the other ranks are ready.
                    torch.distributed.barrier()
                destroy_model_parallel()
                initialize_model_parallel(new_model_parallel_size)
                if local_rank == 0:
                    mp_split_model_rank0(model,
                                         model_full,
                                         use_node_group=use_node_group)
                    del model_full
                else:
                    mp_split_model_receive(model,
                                           use_node_group=use_node_group)
                reset_random_seed(6)
            else:
                overwrite_args.pop('model_parallel_size')
                model, model_args = cls.from_pretrained_base(
                    name,
                    args=args,
                    home_path=home_path,
                    url=url,
                    prefix=prefix,
                    build_only=False,
                    overwrite_args=overwrite_args,
                    **kwargs)
                rank = torch.distributed.get_rank()
                world_size = torch.distributed.get_world_size()
                assert world_size == model_args.model_parallel_size, 'world size should be equal to model_parallel_size.'
                destroy_model_parallel()
                initialize_model_parallel(1)
                if rank == 0:
                    args.use_gpu_initialization = False
                    args.device = 'cpu'
                    overwrite_args['model_parallel_size'] = 1
                    model_full, args_ = cls.from_pretrained_base(
                        name,
                        args=args,
                        home_path=home_path,
                        url=url,
                        prefix=prefix,
                        build_only=True,
                        overwrite_args=overwrite_args,
                        **kwargs)
                torch.distributed.barrier()
                destroy_model_parallel()
                initialize_model_parallel(model_args.model_parallel_size)
                if rank == 0:
                    mp_merge_model_rank0(model, model_full)
                    model, model_args = model_full, args_
                else:
                    mp_merge_model_send(model)
                    model_args.model_parallel_size = 1
                destroy_model_parallel()
                initialize_model_parallel(1)
            return model, model_args
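
    # Repartition sketch (hedged; `MyModel` and 'some-model' are placeholders):
    # overwriting `model_parallel_size` routes through the split/merge logic
    # above instead of a plain load.
    #
    #     # split a checkpoint saved with model_parallel_size=1 across 2 ranks
    #     model, args = MyModel.from_pretrained(
    #         'some-model', overwrite_args={'model_parallel_size': 2})
    #
    #     # merge a model-parallel checkpoint back to a single rank
    #     model, args = MyModel.from_pretrained(
    #         'some-model', overwrite_args={'model_parallel_size': 1})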

    @classmethod
    def list_avail_args(cls, print=True):
        '''List all available args of the current model.'''
        parser = argparse.ArgumentParser()
        from sat.arguments import add_model_config_args
        add_model_config_args(parser)
        # add args of the current model
        if hasattr(cls, 'add_model_specific_args'):
            cls.add_model_specific_args(parser)
        if print:
            from sat.helpers import print_parser
            print_parser(parser)
        return parser

    @classmethod
    def get_args(cls, **kwargs):
        '''Get the parsed args of the current model.
            Args:
                **kwargs: will override the default args.
            Returns:
                args: the parsed args.
        '''
        parser = cls.list_avail_args(print=False)
        # parse an empty argv to get the defaults, then override with kwargs
        args = parser.parse_args([])
        for k, v in kwargs.items():
            if hasattr(args, k) or k in [
                    'fp16'
            ]:  # non-arch args but affect building models
                setattr(args, k, v)
            else:
                print_rank0(
                    f'warning: Unknown arg {k} for class {cls.__name__}.',
                    level='DEBUG')
                setattr(args, k, v)
        return args
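
    # e.g. (hedged): args = MyModel.get_args(num_layers=12, fp16=True)
    # `fp16` is not an architecture argument but still affects model building,
    # so it is accepted without a warning.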
